import pandas as pd
import os
import math
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
# Global matplotlib settings for every chart produced by this module.
# The chart titles and labels below are Chinese; uncomment the next line (requires
# the SimHei font to be installed) so those labels render instead of empty boxes.
# plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False # draw negative signs with ASCII "-" so they display correctly

# Display settings for the metrics in report.csv, in summary-table column order.
# Each row is (key, cname, round, %, show):
#   key   - metric name as it appears in the report's "ind" column
#   cname - header shown in the table
#   round - number of decimals to keep
#   %     - multiply by 100 and append "%"
#   show  - include the metric in the summary table
# "%year" is a template for per-year return rows: its cname ("20") is the substring
# ShowPlot.show uses to pick the year rows out of the report.
_SHOW_CFG_ROWS = (
    ("init_cash",        "初始金额",   0, False, True),
    ("trade_pnl",        "净利润",     0, False, True),
    ("trade_comm",       "手续费",     0, False, True),
    ("total_ret",        "累计收益",   0, True,  True),
    ("mdd",              "最大回撤",   0, True,  True),
    ("mddlen",           "回撤周期",   0, False, True),
    ("win_rate",         "胜率",       0, True,  True),
    ("trade_count",      "交易数",     0, False, True),
    ("profit_factor",    "盈利因子",   2, False, True),
    ("profit_expect",    "预期回报",   2, False, True),
    ("win_loss_ratio",   "盈亏比",     2, False, True),
    ("rets_risk_ratio",  "收益风险比", 2, False, True),
    ("worse_ret",        "最坏亏损",   2, True,  True),
    ("sharp",            "Sharp",      2, False, True),
    ("%year",            "20",         2, True,  True),
    ("avg_win",          "平均盈利",   0, False, False),
    ("avg_loss",         "平均亏损",   0, False, False),
    ("annual_ret",       "年化收益",   0, True,  False),
    ("worse_loss",       "最大亏损",   0, False, False),
    ("sqn",              "SQN",        2, False, False),
    ("win_pnl",          "总盈利",     0, False, False),
    ("lost_pnl",         "总亏损",     0, False, False),
    ("win_longest",      "最长盈利",   0, False, False),
    ("lost_longest",     "最长亏损",   0, False, False),
    ("win_longest_pnl",  "连续盈利",   0, False, False),
    ("lost_longest_pnl", "连续亏损",   0, False, False),
    ("best_win",         "最大盈利",   0, False, False),
    ("best_ret",         "最大收益率", 2, True,  False),
    ("longest_held",     "最长持仓",   0, False, False),
    ("avg_held",         "平均持仓",   0, False, False),
    ("best_month_ret",   "最佳月收益", 0, True,  False),
    ("worst_month_ret",  "最坏月收益", 0, True,  False),
    ("avg_month_ret",    "平均月收益", 0, True,  False),
    ("win_month_count",  "盈利月份",   0, False, False),
    ("lost_month_count", "亏损月份",   0, False, False),
)

show_cfg = {
    key: {"cname": cname, "round": digits, "%": as_percent, "show": visible}
    for key, cname, digits, as_percent, visible in _SHOW_CFG_ROWS
}


class ShowPlot:
    """Draw charts and summary tables for backtest results stored on disk.

    Results are read from ``{fdir}/{strategy}/{symbol_dir}/{run_name}/``, where
    ``symbol_dir`` is ``symbol`` with "/" replaced by "_" (e.g. "BTC/USDT" ->
    "BTC_USDT"). Depending on the method, a run directory must contain
    ``report.csv``, ``nav.csv``, ``trades.csv``, ``bymonth.csv`` and/or
    ``byyear.csv``.
    """

    def __init__(self, fdir, symbol, benchmark, strats=None, show_fig=True):
        """
        Args:
            fdir: root directory of the backtest output.
            symbol: traded symbol, e.g. "BTC/USDT".
            benchmark: strategy directory whose run serves as the benchmark in ``show``.
            strats: strategy directory names to report on (``plot1``/``plot2`` use
                the first one). Defaults to an empty list.
            show_fig: when False, ``show`` saves its figure as ``report.png``.
        """
        self.fdir = fdir
        self.symbol = symbol
        self.benchmark = benchmark
        # Fresh list per instance instead of a shared mutable default argument.
        self.strats = [] if strats is None else strats
        self.show_fig = show_fig

    # ------------------------------------------------------------------ helpers

    def _symbol_name(self):
        """Return the symbol as used in directory names ("BTC/USDT" -> "BTC_USDT")."""
        return self.symbol.replace("/", "_")

    def _run_path(self, name):
        """Return the directory of run ``name`` under the first strategy in ``self.strats``."""
        if not self.strats:
            raise IndexError("ShowPlot.strats is empty; plot1/plot2 use the first strategy")
        return f"{self.fdir}/{self.strats[0]}/{self._symbol_name()}/{name}"

    @staticmethod
    def _match_names(candidates, key_word):
        """Return the candidates that contain any keyword as a substring.

        The result is ordered by keyword first, then by candidate order, and each
        candidate appears at most once. If ``key_word`` is empty or None, all
        candidates are returned.
        """
        if not key_word:
            return list(candidates)
        matched = []
        for key in key_word:
            for candidate in candidates:
                # De-duplicate on the candidate name (the old code checked the keyword).
                if key in candidate and candidate not in matched:
                    matched.append(candidate)
        return matched

    @staticmethod
    def _filter_nav(nav_df, start=None, end=None, align=True):
        """Restrict a nav.csv frame to ``[start, end]`` and optionally normalise it.

        Args:
            nav_df: frame read from ``nav.csv``. It must have ``date``, ``value``,
                ``symbol`` and ``cash`` columns; the last two are dropped.
            start, end: optional inclusive bounds compared against ``date``.
            align: if True, divide ``value`` by its first value so the curve starts at 1.

        Returns:
            A new DataFrame; ``nav_df`` itself is not modified.
        """
        if start:
            nav_df = nav_df[nav_df["date"] >= start]
        if end:
            nav_df = nav_df[nav_df["date"] <= end]
        # Work on a copy so assignments never hit a filtered view
        # (this avoids SettingWithCopyWarning and leaves the caller's frame untouched).
        nav_df = nav_df.copy()
        if align and len(nav_df) > 0:
            nav_df["value"] = nav_df["value"] / nav_df["value"].iloc[0]
        return nav_df.drop(columns=["symbol", "cash"])

    @staticmethod
    def _format_value(value, cfg):
        """Format one metric value for the summary table according to its cfg entry.

        NaN/None becomes "-". Otherwise the value is multiplied by 100 when
        ``cfg["%"]`` is set, rounded to ``cfg["round"]`` decimals (and made an int
        when that is 0), and suffixed with "%" for percentage metrics. Non-finite
        values are left unrounded, because ``int(inf)`` would raise.
        """
        if pd.isna(value):
            return "-"
        value = float(value)
        if cfg["%"]:
            value *= 100
        if math.isfinite(value):
            value = float(np.round(value, cfg["round"]))
            if cfg["round"] == 0:
                value = int(value)
        return f"{value}%" if cfg["%"] else value

    @classmethod
    def _format_report(cls, report_df, cfg):
        """Turn merged report values into the table shown under the NAV chart.

        Args:
            report_df: frame with an ``ind`` column of metric keys plus one value
                column per run (e.g. "benchmark", "strat[1]", ...).
            cfg: metric display config with the same layout as module-level ``show_cfg``.
                It is not modified.

        Returns:
            DataFrame with one row per run and one column per displayed metric.
            Columns are renamed to each metric's ``cname``.
        """
        table = report_df.set_index("ind").T
        cfg = dict(cfg)  # local copy: per-year entries are added below
        cols = [k for k, v in cfg.items() if v["show"] and k != "%year"]
        year_cfg = cfg.get("%year")
        if year_cfg and year_cfg["show"]:
            # Year rows (e.g. "2021") are recognised by containing the "%year" cname.
            year_cols = [c for c in table.columns if year_cfg["cname"] in str(c)]
            cols += year_cols
            for c in year_cols:
                cfg[c] = {**year_cfg, "cname": c}

        table = table[cols]
        formatted = pd.DataFrame(
            {col: [cls._format_value(v, cfg[col]) for v in table[col]] for col in cols},
            index=table.index,
        )
        formatted.columns = [cfg[c]["cname"] for c in cols]
        return formatted

    def _collect_runs(self, name, key_word, symbol_name):
        """Map each strategy in ``self.strats`` to its existing run directories that match.

        If ``name`` is given, only that run is considered. Otherwise every
        sub-directory is considered, in sorted order. ``key_word`` filters the
        runs further; see ``_match_names``.
        """
        runs = OrderedDict()
        for strat in self.strats:
            strat_dir = f"{self.fdir}/{strat}/{symbol_name}"
            if name:
                candidates = [name]
            else:
                candidates = [e for e in sorted(os.listdir(strat_dir))
                              if os.path.isdir(os.path.join(strat_dir, e))]
            runs[strat] = [run for run in self._match_names(candidates, key_word)
                           if os.path.exists(os.path.join(strat_dir, run))]
        return runs

    def _check_runs(self, runs, symbol_name):
        """Check that all runs share initial cash and date range.

        Returns:
            The common "cash_start_end" string, which benchmark run names end with.

        Raises:
            SystemError: no runs were found, or the runs disagree.
        """
        items = []
        for strat, run_names in runs.items():
            for run in run_names:
                report_df = pd.read_csv(os.path.join(self.fdir, strat, symbol_name, run, "report.csv"))
                keys = report_df.set_index("ind")["value"].to_dict()
                start_date = int(keys["start_date"])
                end_date = int(keys["ended_date"])
                cash = int(keys["init_cash"])
                print("[%40s-%10s-%40s]: {%s-%s, %s}" % (strat, symbol_name, run, start_date, end_date, cash))
                items.append("%s_%s_%s" % (cash, start_date, end_date))
        if not items:
            raise SystemError("No Found match strategy.")
        if len(set(items)) > 1:
            raise SystemError("Found the strategy start date and end date are different.")
        return items[0]

    # ------------------------------------------------------------------ public API

    def show(self, name, figsize=(30, 15), key_word=None, start=None, end=None, align=True):
        """Plot benchmark and strategy NAV curves above a table of key metrics.

        Args:
            name: run directory to compare across ``self.strats``. If falsy, every
                run directory of each strategy is considered.
            figsize: matplotlib figure size.
            key_word: optional substrings; only run names containing one of them are kept.
            start, end: optional inclusive date bounds applied to the NAV curves.
            align: normalise every NAV curve so it starts at 1.

        Returns:
            The ``matplotlib.pyplot`` module, with the new figure as the current figure.
            When ``self.show_fig`` is False, the figure is also saved as
            ``report.png`` in the last run's directory.

        Raises:
            SystemError: no run matched; runs disagree on cash or dates; a NAV
                range is empty; or a NAV date range differs from the benchmark's.
            FileNotFoundError: no benchmark run with the same cash/dates exists.
        """
        symbol_name = self._symbol_name()

        # Select the runs to compare and make sure they are comparable.
        runs = self._collect_runs(name, key_word, symbol_name)
        run_suffix = self._check_runs(runs, symbol_name)

        # Benchmark: the run of self.benchmark whose name ends with the same
        # cash/start/end suffix. Sorted so the choice is deterministic.
        benchmark_root = f"{self.fdir}/{self.benchmark}/{symbol_name}"
        matches = [e for e in sorted(os.listdir(benchmark_root))
                   if os.path.isdir(os.path.join(benchmark_root, e)) and e.endswith(run_suffix)]
        if not matches:
            raise FileNotFoundError(f"{benchmark_root} has not benchmark report.csv")
        benchmark_path = os.path.join(benchmark_root, matches[0])

        merge_df = pd.read_csv(os.path.join(benchmark_path, "report.csv"))
        del merge_df["symbol"]
        merge_df = merge_df.rename(columns={"value": "benchmark"})

        benchmark_nav_df = self._filter_nav(
            pd.read_csv(os.path.join(benchmark_path, "nav.csv")), start, end, align)
        if benchmark_nav_df.empty:
            raise SystemError(f"benchmark {self.benchmark} has no nav data between {start} and {end}")
        # Keep only the NAV curve. Other nav.csv columns (e.g. "pos") would otherwise
        # be drawn as extra lines, and strategies only contribute their NAV.
        merge_nav_df = benchmark_nav_df[["date", "value"]].rename(columns={"value": self.benchmark})

        out = None  # directory of the last run; report.png is saved there
        index = 0
        for strat, run_names in runs.items():
            for run_name in run_names:
                index += 1
                label = f"{strat}_{run_name}[{index}]"
                out = os.path.join(self.fdir, strat, symbol_name, run_name)

                strat_nav_df = self._filter_nav(
                    pd.read_csv(os.path.join(out, "nav.csv")), start, end, align)
                # A mismatch in either the first or the last date means the curves
                # cover different periods. The old check raised only when both differed.
                if (strat_nav_df.empty
                        or benchmark_nav_df["date"].iloc[0] != strat_nav_df["date"].iloc[0]
                        or benchmark_nav_df["date"].iloc[-1] != strat_nav_df["date"].iloc[-1]):
                    raise SystemError(
                        f"benchmark {self.benchmark} date duration is different from strategy {strat}/{run_name}")
                # Join on date rather than on row position, so a missing row in
                # either file cannot silently shift the curve.
                merge_nav_df = merge_nav_df.merge(
                    strat_nav_df[["date", "value"]].rename(columns={"value": label}),
                    on="date", how="left")

                strat_df = pd.read_csv(os.path.join(out, "report.csv"))
                merge_df[f"strat[{index}]"] = strat_df["value"]

        table_df = self._format_report(merge_df, show_cfg)

        merge_nav_df["date"] = pd.to_datetime(merge_nav_df["date"])
        merge_nav_df = merge_nav_df.set_index("date")

        # NAV curves
        plt.figure(figsize=figsize)
        plt.subplot2grid((6, 1), (0, 0), rowspan=4)
        plt.plot(merge_nav_df)
        plt.legend(labels=merge_nav_df.columns)
        plt.grid()

        # Metrics table
        plt.subplot2grid((6, 1), (4, 0), rowspan=2)
        plt.axis('off')
        the_table = plt.table(cellText=table_df.values, rowLoc="right", rowLabels=table_df.index,
                              colLabels=table_df.columns,
                              colWidths=[0.02] * len(table_df.columns),
                              loc='best', cellLoc='center')
        the_table.auto_set_column_width(0)
        the_table.set_fontsize(40)
        the_table.scale(2.9, 2.9)
        if not self.show_fig:
            plt.savefig(f"{out}/report.png")
        return plt

    @staticmethod
    def _bar_compare(strat_ret, bench_ret, labels, title, hlines=()):
        """Draw side-by-side strategy and benchmark bars on the current axes."""
        bar_width = 0.3
        positions = np.arange(len(labels))
        plt.bar(x=positions, height=strat_ret, width=bar_width, label="策略")
        plt.bar(x=positions + bar_width, height=bench_ret, width=bar_width, label="基准")
        for level in hlines:
            plt.axhline(level, ls="--", color="r")
        plt.legend()
        plt.xticks(positions + bar_width / 2, labels)
        plt.ylabel(title)
        plt.title(title)
        plt.grid()

    def plot1(self, name, figsize=(30, 15)):
        """Save ``summary1.png`` for run ``name`` of the first strategy.

        Panels:
          1. NAV and drawdown, from the ``fundvalue`` column of ``trades.csv``.
          2. Per-trade return, from ``pnlret``.
          3. Monthly returns, strategy vs. benchmark (``bymonth.csv``).
          4. Yearly returns, strategy vs. benchmark (``byyear.csv``).
        """
        ppath = self._run_path(name)
        trade_df = pd.read_csv(os.path.join(ppath, "trades.csv"))
        fund_value = trade_df["fundvalue"]
        drawdown = fund_value / fund_value.cummax() - 1

        plt.figure(figsize=figsize)

        # NAV on the left axis, drawdown on the right axis
        ax1 = plt.subplot2grid((9, 6), (0, 0), rowspan=4, colspan=6)
        ax1.plot(fund_value / fund_value.iloc[0], 'r', label="净值")
        ax1.set_ylabel("累计净值")
        ax2 = ax1.twinx()
        ax2.plot(drawdown, 'b', label="回撤")
        ax2.set_ylabel("历史回撤")
        # twinx() makes ax2 current, so a plain plt.legend() would list only the
        # drawdown line; merge the entries of both axes instead.
        handles1, labels1 = ax1.get_legend_handles_labels()
        handles2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(handles1 + handles2, labels1 + labels2)
        plt.title("净值/回撤图")
        plt.grid()

        # Per-trade return; dashed reference lines at -2%, -5% and -10%
        plt.subplot2grid((9, 6), (4, 0), rowspan=3, colspan=6)
        plt.bar(x=np.arange(len(trade_df)), height=trade_df["pnlret"], width=0.5, label="pnl")
        for level in (-0.02, -0.05, -0.1):
            plt.axhline(level, ls="--", color="r")
        plt.title("交易损益(%)")
        plt.grid()

        # Monthly returns: strategy vs. benchmark
        bymonth_df = pd.read_csv(os.path.join(ppath, "bymonth.csv"))
        plt.subplot2grid((9, 6), (7, 0), rowspan=2, colspan=5)
        self._bar_compare(bymonth_df["total_ret"], bymonth_df["b&h_ret"],
                          list(bymonth_df["month"].values), "月度收益率(%)", hlines=(-0.1, -0.2))

        # Yearly returns: strategy vs. benchmark
        byyear_df = pd.read_csv(os.path.join(ppath, "byyear.csv"))
        del byyear_df["symbol"]
        byyear_df = byyear_df.set_index("ind").T
        # The last value column of byyear.csv is dropped; presumably an all-period
        # total rather than a year. TODO confirm against the writer of byyear.csv.
        byyear_df = byyear_df.loc[byyear_df.index[0:-1]]
        byyear_df.columns = [c.strip() for c in byyear_df.columns]
        plt.subplot2grid((9, 6), (7, 5), rowspan=2, colspan=1)
        self._bar_compare(byyear_df["total_ret"], byyear_df["b&h_ret"],
                          list(byyear_df.index), "年度收益率(%)")

        plt.tight_layout()
        plt.savefig(f"{ppath}/summary1.png")

    @staticmethod
    def _stat_panel(loc, lines, hline, title):
        """Plot (series, label) pairs in a 3x3 cell of the 9x6 grid with a dashed reference line."""
        plt.subplot2grid((9, 6), loc, rowspan=3, colspan=3)
        for series, label in lines:
            plt.plot(series, label=label)
        plt.axhline(hline, ls="--", color="r")
        plt.title(title)
        plt.legend()
        plt.grid()

    def plot2(self, name, figsize=(30, 15), period=15):
        """Save ``summary2.png`` with cumulative and rolling trade statistics.

        Reads ``trades.csv`` (``pnlret``, ``pnlcomm``) and ``nav.csv`` (``pos``) of
        run ``name`` of the first strategy. Panels show win rate, win/loss ratio,
        profit factor, expectancy, rolling SQN and capital utilisation. The
        rolling windows span ``period`` trades, clamped to at least 15.
        """
        period = max(period, 15)
        ppath = self._run_path(name)
        trade_df = pd.read_csv(os.path.join(ppath, "trades.csv"))

        pnlret = trade_df["pnlret"]
        pnl = trade_df["pnlcomm"]
        is_win = pnl > 0
        is_loss = pnl < 0
        gains = is_win * pnl    # pnl of winning trades, 0 elsewhere
        losses = is_loss * pnl  # pnl of losing trades (negative), 0 elsewhere

        plt.figure(figsize=figsize)

        # Win rate: cumulative (zero-return trades excluded) and over the last `period` trades
        win_rate = (pnlret > 0).cumsum() / (pnlret != 0).cumsum()
        win_rate_n = ((pnlret > 0).rolling(period, min_periods=10).sum()
                      / pnlret.rolling(period, min_periods=10).count())
        self._stat_panel((0, 0), [(win_rate, "胜率"), (win_rate_n, f"胜率[{period}]")],
                         0.3, "胜率(%)")

        # Win/loss ratio: average win / average loss
        win_loss = (gains.cumsum() / is_win.cumsum()) / abs(losses.cumsum() / is_loss.cumsum())
        win_loss_n = ((gains.rolling(period).sum() / is_win.rolling(period).sum())
                      / abs(losses.rolling(period).sum() / is_loss.rolling(period).sum()))
        self._stat_panel((0, 3), [(win_loss, "盈亏比"), (win_loss_n, f"盈亏比[{period}]")],
                         3, "盈亏比(平均盈利/平均亏损)")

        # Profit factor: total gains / total losses
        profit_factor = gains.cumsum() / abs(losses.cumsum())
        profit_factor_n = gains.rolling(period).sum() / abs(losses.rolling(period).sum())
        self._stat_panel((3, 0), [(profit_factor, "累积盈利因子"), (profit_factor_n, f"盈利因子[{period}]")],
                         1, "盈利因子(总盈利/总亏损)")

        # Expectancy: win_rate * win_loss - (1 - win_rate)
        profit_expect = win_rate * win_loss - (1 - win_rate)
        profit_expect_n = win_rate_n * win_loss_n - (1 - win_rate_n)
        self._stat_panel((3, 3), [(profit_expect, "预期回报"), (profit_expect_n, f"预期回报[{period}]")],
                         0, "预期回报(胜率*盈亏比-(1 - 胜率))")

        # Rolling SQN = mean / stdev * sqrt(n) over the last `period` trades
        sqn_n = pnl.rolling(period).mean() / pnl.rolling(period).std() * math.sqrt(period)
        self._stat_panel((6, 0), [(sqn_n, f"SQN[{period}]")], 0, "SQN(mean/stdev*sqrt(n))")

        # Capital utilisation (position size in percent), plotted against row number
        nav_df = pd.read_csv(os.path.join(ppath, "nav.csv"))
        plt.subplot2grid((9, 6), (6, 3), rowspan=3, colspan=3)
        plt.plot(nav_df.index, nav_df["pos"] * 100, label="资金利用率")
        plt.title("资金利用率")
        plt.legend()
        plt.grid()

        plt.tight_layout()
        plt.savefig(f"{ppath}/summary2.png")


if __name__ == "__main__":
    # Local experiment settings; adjust these paths and names for your own machine.
    fdir = "/Users/wudi/Workspace/project/mercury-quant/strategy/experiment/backtest"
    symbol = "BTC/USDT"
    benchmark = "00_base_strat"
    strats = ["01_emaatr_bigperiod_risk2_strat"]
    key_word = ["20190101"]
    name = "30_80_7_10_0_2_1d_10_1_95_20000_20181201_20220214"

    # ShowPlot's constructor only stores its arguments, so one instance serves every chart.
    plotter = ShowPlot(fdir, symbol, benchmark, strats, show_fig=False)
    # Multi-run comparison (NAV curves plus metrics table):
    # plotter.show(name=None, figsize=(25, 10), key_word=key_word)
    plotter.plot1(name=name)
    plotter.plot2(name=name)


